#!/usr/bin/env python2.7
#-*- coding:utf-8 -*-
#Author='He Rensheng'
#Email='hrs323@126.com'

'''
These functions export a PyTorch model to ONNX and then convert it to a
Caffe2 model. They are adapted from the ONNX tutorial:
https://github.com/onnx/tutorials/blob/master/tutorials/PytorchCaffe2MobileSqueezeNet.ipynb

Revision 1.
   import caffe2.python.onnx.backend  --> import onnx_caffe2.backend
Revision 2.
   from caffe2.python.onnx.backend import Caffe2Backend as c2
        --> from onnx_caffe2.backend import Caffe2Backend as c2

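Reason for the revisions:
    With the caffe2.python.onnx backend, the call
    "init_net, predict_net = c2.onnx_graph_to_caffe2_net(model.graph)" raises
        AttributeError: 'GraphProto' object has no attribute 'graph'
    (that version apparently expects the whole ModelProto rather than its
    graph), whereas the onnx_caffe2 backend accepts the GraphProto.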

'''

import numpy as np
import torch
from torch.autograd import Variable
import onnx
import onnx_caffe2.backend
from onnx_caffe2.backend import Caffe2Backend as c2
from caffe2.python import workspace


def pytorch2onnx(torch_model, onnx_model_path, input_shape=(1, 3, 224, 224)):
    """Export a PyTorch model to an ONNX file by tracing it on a random input.

    Args:
        torch_model: the PyTorch module to export.
        onnx_model_path: where to save the model (a path or a file-like
            object, as accepted by ``torch.onnx._export``).
        input_shape: shape of the random dummy input used for tracing.
            Defaults to (1, 3, 224, 224), the shape this function always
            used before the parameter existed.

    Returns:
        A ``(torch_out, x)`` tuple: the value returned by
        ``torch.onnx._export`` (the model output for ``x``) and the dummy
        input ``Variable``. ``onnx2caffe2`` uses both to check the converted
        model numerically.
    """
    # Only the shape of the dummy input matters, not its values: tracing
    # records this exact shape (batch size included) in the exported graph.
    x = Variable(torch.randn(*input_shape), requires_grad=True)
    # _export is used because it returns the model's output for ``x``,
    # which is later compared against the Caffe2 result.
    torch_out = torch.onnx._export(torch_model,         # model being run
                                   x,                   # model input (or a tuple for multiple inputs)
                                   onnx_model_path,     # where to save the model
                                   export_params=True)  # store the trained weights inside the model file
    return torch_out, x


def onnx2caffe2(onnx_path, torch_out, x, caffe2_init_path, caffe2_predict_path):
    """Convert an ONNX model to serialized Caffe2 nets after checking its output.

    Args:
        onnx_path: path of the ONNX model (e.g. written by ``pytorch2onnx``).
        torch_out: PyTorch output for ``x``, as returned by ``pytorch2onnx``.
        x: the input ``Variable`` that produced ``torch_out``.
        caffe2_init_path: file to write the serialized Caffe2 init net to.
        caffe2_predict_path: file to write the serialized Caffe2 predict net to.

    Raises:
        AssertionError: if the Caffe2 output differs from ``torch_out``
            beyond 3 decimal places.
        Whatever ``onnx.checker.check_model`` raises for an invalid model.
    """
    # Load the ONNX model (a standard Python protobuf object) and validate it.
    model = onnx.load(onnx_path)
    onnx.checker.check_model(model)
    # printable_graph returns a string instead of printing it, so print it
    # explicitly to show the graph.
    print(onnx.helper.printable_graph(model.graph))

    # Convert the ONNX graph into a Caffe2 NetDef that can execute it.
    prepared_backend = onnx_caffe2.backend.prepare(model)

    # Map input names to tensor data. In the exported graph the image is
    # input[0]; the remaining inputs are the weights, which are already
    # embedded in the model, so only the image needs to be fed.
    inputs = {model.graph.input[0].name: x.data.cpu().numpy()}

    # Run the Caffe2 net.
    c2_out = prepared_backend.run(inputs)[0]

    # Verify the numerical correctness up to 3 decimal places.
    np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)

    # Export to mobile. onnx_caffe2's converter takes the GraphProto (see the
    # module docstring for why onnx_caffe2 is used instead of caffe2.python.onnx).
    init_net, predict_net = c2.onnx_graph_to_caffe2_net(model.graph)
    with open(caffe2_init_path, "wb") as f:
        f.write(init_net.SerializeToString())
    with open(caffe2_predict_path, "wb") as f:
        f.write(predict_net.SerializeToString())

    # Verify the written nets run with a Caffe2 Predictor.
    verify(caffe2_init_path, caffe2_predict_path)


def verify(caffe2_init_path, caffe2_predict_path, input_shape=(1, 3, 224, 224)):
    """Smoke-test serialized Caffe2 nets by running one random input through them.

    Prints either a success message (including ``result.shape[1]``, reported
    as the number of classes) or the error raised while running the predictor.

    Args:
        caffe2_init_path: path of the serialized Caffe2 init net.
        caffe2_predict_path: path of the serialized Caffe2 predict net.
        input_shape: shape of the random float32 input fed to the predictor.
            Defaults to (1, 3, 224, 224), the previously hard-coded shape.
    """
    # The nets were written as binary protobufs, so read them back as bytes.
    # Text mode would mangle them on Windows and return str on Python 3.
    with open(caffe2_init_path, "rb") as f:
        init_net = f.read()
    with open(caffe2_predict_path, "rb") as f:
        predict_net = f.read()
    p = workspace.Predictor(init_net, predict_net)
    # The following code should run:
    try:
        img = np.random.rand(*input_shape).astype(np.float32)
        result, = p.run([img])
    except Exception as e:
        print('Unexpected Error: {}'.format(e))
        return
    else:
        # Success-only path. A ``finally`` clause here would also run after a
        # failure, printing "Congratulations!" and then crashing on the
        # unbound ``result``.
        print('')
        print('Congratulations!')
        print('The model works. Our model can produces prediction '
              'for each of ImageNet {} classes'.format(result.shape[1]))
        print('')
        print("Next, Copy output/caffe2/[squeeze_init_net.pb, squeeze_predict_net.pb] "
              "to AICamera/app/src/main/assets. ")
        print("Now we can open Android Studio and import the AICamera project, "
              "run the app by clicking the green play button.")
        print('')
        print("You can check Caffe2 AI Camera tutorial for more details of how "
              "Caffe2 can be invoked in the Android mobile app.")
        print("Caffe2 AI Camera tutorial URL is https://caffe2.ai/docs/AI-Camera-demo-android.html")